# Model section: `class_name` looks like a `module>Class` reference and
# `config` carries the settings handed to it (confirm against the loader).
# NOTE(review): the `{{ ... }}` placeholder below is not valid plain YAML, so
# this file is presumably rendered by a template engine before being parsed.
model_config:
  class_name: tensorflow_asr.models.ctc.transformer>Transformer
  config:
    name: transformer
    # Audio front-end settings; `_ms` suffixes presumably mean milliseconds.
    speech_config:
      sample_rate: 16000
      frame_ms: 25
      stride_ms: 10
      nfft: 512
      num_feature_bins: 80
      feature_type: log_mel_spectrogram
      augmentation_config:
        feature_augment:
          # NOTE(review): mask_factor -1 together with p_upperbound looks like
          # "size the mask adaptively" - confirm in the augmentation code.
          time_masking:
            prob: 1.0
            num_masks: 5
            mask_factor: -1
            p_upperbound: 0.05
          freq_masking:
            prob: 1.0
            num_masks: 2
            mask_factor: 27
    # Every list below has two entries, presumably one per subsampling layer;
    # keep their lengths equal when editing.
    encoder_subsampling:
      type: conv2d
      filters: [512, 512]
      kernels: [3, 3]
      strides: [2, 2]
      paddings: ["causal", "causal"]
      norms: ["batch", "batch"]
      activations: ["relu", "relu"]
    encoder_dropout: 0.1
    encoder_residual_factor: 1.0
    encoder_norm_position: post
    # encoder_head_size * encoder_num_heads = 128 * 4 = 512 = encoder_dmodel.
    encoder_dmodel: 512
    encoder_dff: 1024
    encoder_num_blocks: 6
    encoder_head_size: 128
    encoder_num_heads: 4
    encoder_mha_type: mha
    # Booleans are written in canonical lowercase form.
    encoder_interleave_relpe: true
    encoder_use_attention_causal_mask: false
    encoder_use_attention_auto_mask: true
    encoder_pwffn_activation: relu
    encoder_memory_length: null
    encoder_history_size: 64  # frames = 4 * chunk_size
    encoder_chunk_size: 16  # frames
    blank: 0
    # Left unquoted on purpose so the substituted value stays numeric.
    vocab_size: {{decoder_config.vocabsize}}
    kernel_regularizer:
      class_name: l2
      config:
        # Keep the explicit mantissa dot: YAML 1.1 loaders (PyYAML's default
        # resolver among them) read a bare `1e-6` as a string, not a float.
        l2: 1.0e-6

# Training section: optimizer, noise/gradient options, batch sizing and
# callbacks.
# NOTE(review): the `{{ ... }}` placeholders are not valid plain YAML, so this
# file is presumably rendered by a template engine before being parsed. They
# are left unquoted so an empty substitution still yields null, as before.
learning_config:
  optimizer_config:
    class_name: Adam
    config:
      learning_rate:
        class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule
        config:
          # Same value as encoder_dmodel in model_config (512); presumably the
          # two must be changed together - confirm against the schedule.
          dmodel: 512
          warmup_steps: 10000
          max_lr: null
          min_lr: null
      beta_1: 0.9
      beta_2: 0.98
      # Keep the explicit mantissa dot: YAML 1.1 loaders (PyYAML's default
      # resolver among them) read a bare `1e-9` as a string, not a float.
      epsilon: 1.0e-9

  gwn_config:
    predict_net_step: 0
    predict_net_stddev: 0.075

  gradn_config: null

  # NOTE(review): if ga_steps is gradient accumulation, the effective batch
  # is batch_size * ga_steps = 32 - confirm in the trainer.
  batch_size: 8
  ga_steps: 4
  num_epochs: 450

  # Each entry names a callback class and the config passed to it.
  callbacks:
    - class_name: tensorflow_asr.callbacks>TerminateOnNaN
      config: {}
    - class_name: tensorflow_asr.callbacks>ModelCheckpoint
      config:
        # `{epoch:02d}` is a format placeholder, not a template variable;
        # presumably filled in by the checkpoint callback at save time.
        filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5
        save_best_only: false
        save_weights_only: true
        save_freq: epoch
    - class_name: tensorflow_asr.callbacks>TensorBoard
      config:
        log_dir: {{modeldir}}/tensorboard
        histogram_freq: 0
        write_graph: false
        write_images: false
        write_steps_per_second: false
        update_freq: batch
        profile_batch: 0
    - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore
      config:
        model_handle: {{kaggle_model_handle}}
        model_dir: {{modeldir}}
        save_freq: epoch
